/**
 * Copyright 2019-2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "graph/model.h"
#include "securec.h"
#include "framework/graph/core/cgraph/compute_graph.h"
#include "framework/graph/core/cgraph/graph_serializer.h"
#include "framework/graph/utils/graph_utils.h"
#include "framework/graph/debug/ge_log.h"
#include "framework/common/types.h"


// src/framework/inc
#include "infra/base/assertion.h"

#include "graph/persistance/interface/model_def.h"
#include "graph/persistance/interface/graph_def.h"
#include "graph/persistance/proxy/proto_factory.h"

using namespace hiai;
namespace ge {
// Empty string returned by reference from GetName()/GetPlatformVersion() when modelDef_ is null.
static const string STR_EMPTY {};

/**
 * Creates a model through the default constructor, then records its name and custom version.
 * If the factory could not create a model definition (modelDef_ is null), both values are
 * silently dropped.
 */
Model::Model(const string& name, const string& customVersion) : Model()
{
    if (modelDef_ == nullptr) {
        return;
    }
    modelDef_->set_name(name);
    modelDef_->set_custom_version(customVersion);
}

/**
 * Default constructor: asks ProtoFactory for a fresh model definition and delegates to
 * Model(hiai::IModelDef*), which takes ownership of it (the result may be null).
 */
Model::Model()
    : Model(hiai::ProtoFactory::Instance()->CreateModelDef())
{}

/**
 * Wraps an existing model definition. The Model takes ownership of `modelDef`; the destructor
 * releases it through ProtoFactory::DestroyModelDef. `modelDef` may be null.
 * weightBuffer_ is constructed with 0 (presumably an empty buffer; confirm against Buffer).
 */
Model::Model(hiai::IModelDef* modelDef)
    : modelDef_(modelDef),
      weightBuffer_(0)
{}

/**
 * Releases the owned model definition (if any) via ProtoFactory::DestroyModelDef.
 * NOTE(review): modelDef_ is an owning raw pointer. Confirm the class declaration deletes or
 * customizes copy construction/assignment; otherwise a copied Model would destroy it twice.
 */
Model::~Model()
{
    if (modelDef_ == nullptr) {
        return;
    }
    hiai::ProtoFactory::Instance()->DestroyModelDef(modelDef_);
    modelDef_ = nullptr;
}

/** Read-only view of the model attribute map; nullptr when there is no model definition. */
const hiai::IAttrMapDef* Model::GetAttrMapDef() const
{
    if (modelDef_ == nullptr) {
        return nullptr;
    }
    return modelDef_->attr();
}

/** Mutable view of the model attribute map; nullptr when there is no model definition. */
hiai::IAttrMapDef* Model::MutableAttrMapDef()
{
    if (modelDef_ == nullptr) {
        return nullptr;
    }
    return modelDef_->mutable_attr();
}

/** Model name, or a shared empty string when there is no model definition. */
const string& Model::GetName() const
{
    if (modelDef_ == nullptr) {
        return STR_EMPTY;
    }
    return modelDef_->name();
}

/** Sets the model name; a no-op when there is no model definition. */
void Model::SetName(const string& name)
{
    if (modelDef_ == nullptr) {
        return;
    }
    modelDef_->set_name(name);
}

/** Model version narrowed to uint32_t, or 0 when there is no model definition. */
uint32_t Model::GetVersion() const
{
    if (modelDef_ == nullptr) {
        return 0;
    }
    return static_cast<uint32_t>(modelDef_->version());
}

/** Sets the model version; a no-op when there is no model definition. */
void Model::SetVersion(uint32_t version)
{
    if (modelDef_ == nullptr) {
        return;
    }
    modelDef_->set_version(version);
}

/** Platform version (stored as the definition's custom_version), or an empty string if absent. */
const string& Model::GetPlatformVersion() const
{
    if (modelDef_ == nullptr) {
        return STR_EMPTY;
    }
    return modelDef_->custom_version();
}

/** Stores the platform version in the definition's custom_version; no-op without a definition. */
void Model::SetPlatformVersion(const std::string& version)
{
    if (modelDef_ == nullptr) {
        return;
    }
    modelDef_->set_custom_version(version);
}

/** Replaces the model's graph with a copy of `graph`. */
void Model::SetGraph(const ge::Graph& graph)
{
    this->graph_ = graph;
}

/** Returns the model's graph by value. */
Graph Model::GetGraph() const
{
    return this->graph_;
}

/**
 * Copies this model's metadata into `modelDef` and serializes graph_ into a newly appended
 * graph entry of it.
 *
 * @param modelDef target definition; must be non-null.
 * @return GRAPH_SUCCESS, or the HIAI_EXPECT_* failure value if the target or modelDef_ is null,
 *         graph_ has no compute graph, or graph serialization fails.
 */
GraphErrCodeStatus Model::SerializeTo(hiai::IModelDef* modelDef) const
{
    HIAI_EXPECT_NOT_NULL(modelDef);
    HIAI_EXPECT_NOT_NULL(modelDef_);

    // Model-level metadata.
    auto& source = *modelDef_;
    modelDef->set_name(source.name());
    modelDef->set_version(source.version());
    modelDef->set_custom_version(source.custom_version());
    // NOTE(review): hands modelDef_'s mutable attr map to set_attr. Whether set_attr copies it or
    // takes ownership is not visible here; confirm, since callers destroy `modelDef` afterwards.
    modelDef->set_attr(source.mutable_attr());

    // Graph payload.
    auto computeGraph = GraphUtils::GetComputeGraph(graph_);
    HIAI_EXPECT_NOT_NULL(computeGraph);
    HIAI_EXPECT_TRUE(computeGraph->ROLE(GraphSerializer).SerializeTo(modelDef->add_graph()));

    return GRAPH_SUCCESS;
}

/**
 * Fills the partition table that follows the file header.
 * Entry 0 describes the serialized graph at offset 0. When partitionsNum > 1, entry 1 describes
 * the weights, which start right after the graph.
 * Offsets are relative to the data region after the table.
 * Sizes and offsets are truncated to uint32_t.
 */
static void BuildPartitionTable(ModelPartitionTable* partitionTable, uint32_t partitionsNum,
    size_t graphBufferSize, size_t weightBufferSize)
{
    partitionTable->num = partitionsNum;

    const uint32_t graphSize = static_cast<uint32_t>(graphBufferSize);
    const uint32_t graphOffset = 0;
    partitionTable->partition[0] = {ModelPartitionType::MODEL_DEF, graphOffset, graphSize};

    if (partitionsNum <= 1) {
        return;
    }

    const uint32_t weightOffset = graphSize;
    partitionTable->partition[1] = {
        ModelPartitionType::WEIGHTS_DATA, weightOffset, static_cast<uint32_t>(weightBufferSize)};
}

/**
 * Zeroes the ModelFileHeader at `basePtr` and fills it in.
 * Sets the name "default", the IR_API_GRAPH_MODEL type, the magic number, the header size
 * (256), the version (0x10000000) and `totalSize`. The caller passes the total as the size of
 * everything after the header.
 *
 * @return SUCCESS, or the HIAI_EXPECT_TRUE failure value if copying the name fails.
 */
static Status FillModelHeaderParams(uint8_t* basePtr, size_t totalSize)
{
    (void)memset_s(basePtr, sizeof(ModelFileHeader), 0, sizeof(ModelFileHeader));

    auto* modelHeader = reinterpret_cast<ModelFileHeader*>(basePtr);
    const std::string defaultModelName("default");
    HIAI_EXPECT_TRUE(memcpy_s(modelHeader->name, MODEL_NAME_LENGTH,
        defaultModelName.c_str(), defaultModelName.size()) == EOK);

    constexpr uint32_t kModelFileHeadLen = 256;
    constexpr uint32_t kModelVersion = 0x10000000;
    modelHeader->magic = MODEL_FILE_MAGIC_NUM;
    // IR API graph model: the stored graph still needs to be compiled.
    modelHeader->modeltype = ModelType::IR_API_GRAPH_MODEL;
    modelHeader->version = kModelVersion;
    modelHeader->headsize = kModelFileHeadLen;
    modelHeader->length = totalSize;
    return SUCCESS;
}

/**
 * Serializes `modelDef` together with `weightBuffer` into a single model image in `outBuffer`.
 *
 * Layout, in order:
 *   [ModelFileHeader]
 *   [ModelPartitionTable + partitionsNum * ModelPartitionMemInfo]
 *   [serialized graph]
 *   [weights]                  (only when weightBuffer is non-empty)
 *
 * @param modelDef      definition to serialize; must be non-null.
 * @param weightBuffer  weights appended after the graph; may be empty.
 * @param outBuffer     resized to hold the image; cleared on any failure after the resize.
 * @return true on success, false otherwise.
 */
static bool BuildWeightMergedModelBuffer(IModelDef* modelDef, const Buffer& weightBuffer, Buffer& outBuffer)
{
    HIAI_EXPECT_NOT_NULL_R(modelDef, false);

    size_t graphBufferSize = modelDef->GetModelDefSize();
#if defined(AI_SUPPORT_32_BIT_OS)
    // 4 Byte align for 32 bit os
    // NOTE(review): the padding bytes after the serialized graph may not be written by SaveTo;
    // confirm whether they need explicit zeroing.
    if (graphBufferSize % 4 != 0) {
        graphBufferSize = (((graphBufferSize + 3) / 4) * 4);
    }
#endif
    size_t weightBufferSize = weightBuffer.GetSize();
    size_t partitionsNum = 1;
    if (weightBufferSize != 0) {
        partitionsNum += 1;
    }
    size_t partitionTableSize = sizeof(ModelPartitionTable) + sizeof(ModelPartitionMemInfo) * partitionsNum;
    size_t totalBufferSize = sizeof(ModelFileHeader) + partitionTableSize + graphBufferSize + weightBufferSize;

    outBuffer.Resize(totalBufferSize);
    // Scope guard: clears outBuffer on every early return below and is released on success.
    // std::unique_ptr only invokes its deleter when it holds a non-null pointer. The guard
    // therefore owns the address of a dummy local; with nullptr the cleanup would never run.
    int guardToken = 0;
    auto clearBufferFunc = [&outBuffer](int*) {
        outBuffer.Clear();
    };
    std::unique_ptr<int, decltype(clearBufferFunc)> decl(&guardToken, clearBufferFunc);
    uint8_t* basePtr = outBuffer.MutableData();
    HIAI_EXPECT_NOT_NULL_R(basePtr, false);

    size_t offset = 0;
    // 1. Build and make modelHeader (its length field excludes the header itself)
    HIAI_EXPECT_EXEC_R(FillModelHeaderParams(basePtr, partitionTableSize + graphBufferSize + weightBufferSize), false);
    offset += sizeof(ModelFileHeader);

    // 2. Build and make partitionTable
    BuildPartitionTable(reinterpret_cast<ModelPartitionTable*>(basePtr + offset),
        partitionsNum, graphBufferSize, weightBufferSize);
    offset += partitionTableSize;

    // 3. Directly serialize the graph into outBuffer
    if (!modelDef->SaveTo(basePtr + offset, graphBufferSize)) {
        return false;
    }
    offset += graphBufferSize;

    if (partitionsNum > 1) {
        // 4. Copy weightBuffer to outBuffer
        HIAI_EXPECT_TRUE_R(memcpy_s(reinterpret_cast<void*>(basePtr + offset), totalBufferSize - offset,
            weightBuffer.GetData(), weightBuffer.GetSize()) == EOK, false);
    }

    FMK_LOGI("Build success!");
    // Success: dismiss the guard so outBuffer keeps its contents.
    (void)decl.release();
    return true;
}

/**
 * Serializes the model definition (metadata and graph, without weights) into `buffer`.
 * The temporary definition is always destroyed before returning.
 *
 * @return GRAPH_SUCCESS, or GRAPH_FAILED if serialization or SaveTo fails. On SaveTo failure
 *         `buffer` is cleared.
 */
GraphErrCodeStatus Model::Save(Buffer& buffer) const
{
    auto modelDef = hiai::ProtoFactory::Instance()->CreateModelDef();
    HIAI_EXPECT_NOT_NULL(modelDef);

    GraphErrCodeStatus status = GRAPH_FAILED;
    if (SerializeTo(modelDef) == GRAPH_SUCCESS) {
        buffer.Resize(modelDef->GetModelDefSize());
        if (modelDef->SaveTo(buffer.MutableData(), buffer.GetSize())) {
            status = GRAPH_SUCCESS;
        } else {
            buffer.Clear();
        }
    }

    // Single exit point: the temporary definition is released exactly once.
    hiai::ProtoFactory::Instance()->DestroyModelDef(modelDef);
    return status;
}

/**
 * Serializes the model together with its weight buffer into `buffer`, using the
 * header + partition table + graph + weights layout produced by BuildWeightMergedModelBuffer.
 *
 * @return GRAPH_SUCCESS; GRAPH_FAILED if serialization fails; the HIAI_EXPECT_TRUE failure
 *         value if building the merged buffer fails.
 */
GraphErrCodeStatus Model::SaveFullModel(Buffer& buffer) const
{
    auto modelDef = hiai::ProtoFactory::Instance()->CreateModelDef();
    HIAI_EXPECT_NOT_NULL(modelDef);

    const bool serialized = (SerializeTo(modelDef) == GRAPH_SUCCESS);
    // The merged buffer is only built when serialization succeeded (short-circuit).
    bool ret = serialized && BuildWeightMergedModelBuffer(modelDef, weightBuffer_, buffer);
    hiai::ProtoFactory::Instance()->DestroyModelDef(modelDef);

    if (!serialized) {
        return GRAPH_FAILED;
    }
    HIAI_EXPECT_TRUE(ret);

    return GRAPH_SUCCESS;
}

/**
 * Parses a serialized model definition from `data` (`len` bytes) into modelDef_ and rebuilds
 * graph_ from the first graph it contains.
 *
 * The first graph entry of modelDef_ is swapped into a newly created GraphDef, so afterwards
 * modelDef_->graph(0) holds that new GraphDef's previous (presumably empty) contents. graph_
 * is only replaced when every step succeeds.
 *
 * @return GRAPH_SUCCESS, GRAPH_FAILED if the swap fails, or the HIAI_EXPECT_* failure value for
 *         the other failure paths.
 */
GraphErrCodeStatus Model::Load(const uint8_t* data, size_t len)
{
    HIAI_EXPECT_NOT_NULL(modelDef_);
    HIAI_EXPECT_TRUE(modelDef_->LoadFrom(data, len));
    // At least one graph is required; only graph 0 is used.
    HIAI_EXPECT_TRUE(modelDef_->graph_size() > 0);

    hiai::IGraphDef* graphDef = hiai::ProtoFactory::Instance()->CreateGraphDef();
    HIAI_EXPECT_NOT_NULL(graphDef);

    if (!graphDef->Swap(modelDef_->mutable_graph(0))) {
        hiai::ProtoFactory::Instance()->DestroyGraphDef(graphDef);
        return GRAPH_FAILED;
    }

    // NOTE(review): the `true` argument presumably transfers ownership of graphDef to the
    // ComputeGraph; TODO confirm. If Make returns null, this code does not release graphDef.
    // Whether Make frees it on failure is not visible here, so this may be a leak.
    auto computeGraph = ComputeGraph::Make(graphDef, true);
    HIAI_EXPECT_NOT_NULL(computeGraph);

    HIAI_EXPECT_TRUE(computeGraph->ROLE(GraphSerializer).UnSerialize());

    graph_ = GraphUtils::CreateGraphFromComputeGraph(computeGraph);
    return GRAPH_SUCCESS;
}

/**
 * Serializes the model into a temporary definition and dumps it to `outFile` via
 * IModelDef::Dump. The temporary definition is always destroyed before returning.
 *
 * @return GRAPH_SUCCESS, or GRAPH_FAILED if serialization or the dump fails.
 */
GraphErrCodeStatus Model::Dump(const string& outFile)
{
    auto modelDef = hiai::ProtoFactory::Instance()->CreateModelDef();
    HIAI_EXPECT_NOT_NULL(modelDef);

    // Dump is only attempted after a successful serialization (short-circuit).
    const bool dumped = (SerializeTo(modelDef) == GRAPH_SUCCESS) && modelDef->Dump(outFile);
    hiai::ProtoFactory::Instance()->DestroyModelDef(modelDef);

    if (!dumped) {
        return GRAPH_FAILED;
    }
    return GRAPH_SUCCESS;
}

bool Model::IsValid() const
{
    return graph_.IsValid();
}

/**
 * Returns the model-owned weight buffer, resized to `size` bytes if its current size differs.
 * The returned pointer refers to a member of this Model.
 */
Buffer* Model::CreateWeightBuffer(size_t size)
{
    const bool sizeChanged = (weightBuffer_.GetSize() != size);
    if (sizeChanged) {
        weightBuffer_.Resize(size);
    }
    Buffer* weights = &weightBuffer_;
    return weights;
}
} // namespace ge
